{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "maxlen = 20\n",
    "max_vocab = 20000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "word2idx = tf.keras.datasets.imdb.get_word_index()\n",
    "word2idx = {k: (v + 4) for k, v in word2idx.items()}\n",
    "word2idx['<PAD>'] = 0\n",
    "word2idx['<START>'] = 1\n",
    "word2idx['<UNK>'] = 2\n",
    "word2idx['<END>'] = 3\n",
    "idx2word = {i: w for w, i in word2idx.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "(train_X, _), (test_X, _) = tf.contrib.keras.datasets.imdb.load_data(num_words = max_vocab, index_from= 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = np.concatenate([train_X, test_X])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = np.concatenate((tf.keras.preprocessing.sequence.pad_sequences(\n",
    "                            X, maxlen, truncating='post', padding='post'),\n",
    "                        tf.keras.preprocessing.sequence.pad_sequences(\n",
    "                            X, maxlen, truncating='pre', padding='post')))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "Y_input = X[:]\n",
    "Y_output = np.concatenate([X[:, 1:], np.full([X.shape[0], 1], word2idx['<END>'])], 1)\n",
    "X = X[:, 1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((100000, 19), (100000, 20), (100000, 20))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.shape, Y_input.shape, Y_output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.python.util import nest\n",
    "from tensorflow.contrib.seq2seq.python.ops.beam_search_decoder import _beam_search_step\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "class ModifiedBasicDecoder(tf.contrib.seq2seq.BasicDecoder):\n",
    "    def __init__(self, cell, helper, initial_state, concat_z, output_layer=None):\n",
    "        super().__init__(cell, helper, initial_state, output_layer = output_layer)\n",
    "        self.z = concat_z\n",
    "\n",
    "    def initialize(self, name=None):\n",
    "        (finished, first_inputs, initial_state) =  self._helper.initialize() + (self._initial_state,)\n",
    "        first_inputs = tf.concat([first_inputs, self.z], -1)\n",
    "        return (finished, first_inputs, initial_state)\n",
    "\n",
    "    def step(self, time, inputs, state, name=None):\n",
    "        with tf.name_scope(name, \"BasicDecoderStep\", (time, inputs, state)):\n",
    "            cell_outputs, cell_state = self._cell(inputs, state)\n",
    "        print(self._output_layer)\n",
    "        if self._output_layer is not None:\n",
    "            cell_outputs = self._output_layer(cell_outputs)\n",
    "        print(cell_outputs)\n",
    "        sample_ids = self._helper.sample(\n",
    "            time=time, outputs=cell_outputs, state=cell_state)\n",
    "        (finished, next_inputs, next_state) = self._helper.next_inputs(\n",
    "            time=time,\n",
    "            outputs=cell_outputs,\n",
    "            state=cell_state,\n",
    "            sample_ids=sample_ids)\n",
    "        outputs = tf.contrib.seq2seq.BasicDecoderOutput(cell_outputs, sample_ids)\n",
    "        next_inputs = tf.concat([next_inputs, self.z], -1)\n",
    "        return (outputs, next_state, next_inputs, finished)\n",
    "\n",
    "\n",
    "class ModifiedBeamSearchDecoder(tf.contrib.seq2seq.BeamSearchDecoder):\n",
    "    def __init__(self,\n",
    "                 cell,\n",
    "                 embedding,\n",
    "                 start_tokens,\n",
    "                 end_token,\n",
    "                 initial_state,\n",
    "                 beam_width,\n",
    "                 concat_z,\n",
    "                 output_layer=None,\n",
    "                 length_penalty_weight=0.0):\n",
    "        super().__init__(cell, embedding, start_tokens, end_token, initial_state, beam_width, output_layer, length_penalty_weight)\n",
    "        self.z = concat_z\n",
    "\n",
    "    def initialize(self, name=None):\n",
    "        finished, start_inputs = self._finished, self._start_inputs\n",
    "\n",
    "        start_inputs = tf.concat([start_inputs, self.z], -1)\n",
    "\n",
    "        log_probs = tf.one_hot(  # shape(batch_sz, beam_sz)\n",
    "            tf.zeros([self._batch_size], dtype=tf.int32),\n",
    "            depth=self._beam_width,\n",
    "            on_value=0.0,\n",
    "            off_value=-np.Inf,\n",
    "            dtype=nest.flatten(self._initial_cell_state)[0].dtype)\n",
    "\n",
    "        initial_state = tf.contrib.seq2seq.BeamSearchDecoderState(\n",
    "            cell_state=self._initial_cell_state,\n",
    "            log_probs=log_probs,\n",
    "            finished=finished,\n",
    "            lengths=tf.zeros(\n",
    "                [self._batch_size, self._beam_width], dtype=tf.int64),\n",
    "            accumulated_attention_probs=())\n",
    "\n",
    "        return (finished, start_inputs, initial_state)\n",
    "\n",
    "    def step(self, time, inputs, state, name=None):\n",
    "        batch_size = self._batch_size\n",
    "        beam_width = self._beam_width\n",
    "        end_token = self._end_token\n",
    "        length_penalty_weight = self._length_penalty_weight\n",
    "\n",
    "        with tf.name_scope(name, \"BeamSearchDecoderStep\", (time, inputs, state)):\n",
    "            cell_state = state.cell_state\n",
    "            inputs = nest.map_structure(\n",
    "                lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs)\n",
    "            cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state,\n",
    "                                            self._cell.state_size)\n",
    "            cell_outputs, next_cell_state = self._cell(inputs, cell_state)\n",
    "            cell_outputs = nest.map_structure(\n",
    "                lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)\n",
    "            next_cell_state = nest.map_structure(\n",
    "                self._maybe_split_batch_beams, next_cell_state, self._cell.state_size)\n",
    "            print(self._output_layer)\n",
    "            if self._output_layer is not None:\n",
    "                cell_outputs = self._output_layer(cell_outputs)\n",
    "\n",
    "            beam_search_output, beam_search_state = _beam_search_step(\n",
    "                time=time,\n",
    "                logits=cell_outputs,\n",
    "                next_cell_state=next_cell_state,\n",
    "                beam_state=state,\n",
    "                batch_size=batch_size,\n",
    "                beam_width=beam_width,\n",
    "                end_token=end_token,\n",
    "                length_penalty_weight=length_penalty_weight,\n",
    "                coverage_penalty_weight = 0.0)\n",
    "\n",
    "            finished = beam_search_state.finished\n",
    "            sample_ids = beam_search_output.predicted_ids\n",
    "            next_inputs = tf.cond(\n",
    "                tf.reduce_all(finished), lambda: self._start_inputs,\n",
    "                lambda: self._embedding_fn(sample_ids))\n",
    "\n",
    "            next_inputs = tf.concat([next_inputs, self.z], -1)\n",
    "\n",
    "        return (beam_search_output, beam_search_state, next_inputs, finished)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VAE:\n",
    "    def __init__(self, size_layer, num_layers, embedded_size, dict_size, learning_rate,\n",
    "                beam_size = 15, latent_size = 16, anneal_max = 1.0, anneal_bias = 6000):\n",
    "        \n",
    "        def cells(reuse=False):\n",
    "            return tf.nn.rnn_cell.GRUCell(size_layer, reuse=reuse)\n",
    "        \n",
    "        def kl_w_fn(global_step):\n",
    "            return anneal_max * tf.sigmoid((10 / anneal_bias) * \\\n",
    "                                           (tf.to_float(global_step) - tf.constant(anneal_bias / 2)))\n",
    "        \n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y_input = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y_output = tf.placeholder(tf.int32, [None, None])\n",
    "        \n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype=tf.int32)\n",
    "        self.Y_seq_len = tf.count_nonzero(self.Y_input, 1, dtype=tf.int32)\n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "        main = tf.strided_slice(self.Y_input, [0, 0], [batch_size, -1], [1, 1])\n",
    "        decoder_input = tf.concat([tf.fill([batch_size, 1], word2idx['<START>']), main], 1)\n",
    "        \n",
    "        embeddings = tf.Variable(tf.random_uniform([dict_size, embedded_size], -1, 1))\n",
    "        x = tf.nn.embedding_lookup(embeddings, self.X)\n",
    "        \n",
    "        _, encoder_state = tf.nn.dynamic_rnn(\n",
    "            cell = tf.nn.rnn_cell.MultiRNNCell([cells() for _ in range(num_layers)]), \n",
    "            inputs = x,\n",
    "            sequence_length = self.X_seq_len,\n",
    "            dtype = tf.float32)\n",
    "        encoder_state = encoder_state[-1]\n",
    "        \n",
    "        z_mean = tf.layers.dense(encoder_state, latent_size)\n",
    "        z_var = tf.layers.dense(encoder_state, latent_size)\n",
    "        \n",
    "        posterior = tf.contrib.distributions.MultivariateNormalDiag(z_mean, z_var)\n",
    "        prior = tf.contrib.distributions.MultivariateNormalDiag(tf.zeros_like(z_mean),\n",
    "                                                            tf.ones_like(z_var))\n",
    "        z = posterior.sample()\n",
    "        init_state = tf.layers.dense(z, size_layer, tf.nn.elu)\n",
    "        print(dict_size)\n",
    "        output_proj = tf.layers.Dense(dict_size)\n",
    "        print(output_proj)\n",
    "        decoder_cell = cells()\n",
    "        \n",
    "        helper = tf.contrib.seq2seq.TrainingHelper(\n",
    "                inputs = tf.nn.embedding_lookup(embeddings, decoder_input),\n",
    "                sequence_length = self.Y_seq_len)\n",
    "        \n",
    "        decoder = ModifiedBasicDecoder(\n",
    "                cell = decoder_cell,\n",
    "                helper = helper,\n",
    "                initial_state = init_state,\n",
    "                output_layer = output_proj,\n",
    "                concat_z = z)\n",
    "        \n",
    "        decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = tf.reduce_max(self.Y_seq_len))\n",
    "        \n",
    "        self.training_logits = decoder_output.rnn_output\n",
    "        out_dist = tf.distributions.Categorical(self.training_logits)\n",
    "        global_step = tf.Variable(0, trainable=False)\n",
    "        self.out_dist = out_dist.log_prob(self.Y_output)\n",
    "        nll_loss = -tf.reduce_sum(self.out_dist)\n",
    "        self.nll_loss = nll_loss\n",
    "        kl_w = kl_w_fn(global_step)\n",
    "        self.kl_w = kl_w\n",
    "        kl_loss = tf.reduce_sum(tf.distributions.kl_divergence(posterior, prior))\n",
    "        self.kl_loss = kl_loss\n",
    "        self.cost = nll_loss + kl_w * kl_loss\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.cost, \n",
    "                                                                        global_step = global_step)\n",
    "        \n",
    "        tiled_z = tf.tile(tf.expand_dims(z, 1), [1, beam_size, 1])\n",
    "        decoder = ModifiedBeamSearchDecoder(\n",
    "                cell = decoder_cell,\n",
    "                embedding = embeddings,\n",
    "                start_tokens = tf.tile(tf.constant([word2idx['<START>']], tf.int32),\n",
    "                                       [batch_size]),\n",
    "                end_token = word2idx['<END>'],\n",
    "                initial_state = tf.contrib.seq2seq.tile_batch(init_state, beam_size),\n",
    "                beam_width = beam_size,\n",
    "                output_layer = output_proj,\n",
    "                concat_z = tiled_z)\n",
    "        decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                maximum_iterations = tf.reduce_max(self.X_seq_len),\n",
    "                decoder = decoder)\n",
    "        self.predict_ids = decoder_output.predicted_ids[:, :, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "size_layer = 128\n",
    "num_layers = 2\n",
    "embedded_size = 128\n",
    "learning_rate = 1e-3\n",
    "batch_size = 16\n",
    "epoch = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-10-9c757c035d84>:35: MultivariateNormalDiag.__init__ (from tensorflow.contrib.distributions.python.ops.mvn_diag) is deprecated and will be removed after 2018-10-01.\n",
      "Instructions for updating:\n",
      "The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.contrib.distributions`.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/distributions/python/ops/mvn_diag.py:224: MultivariateNormalLinearOperator.__init__ (from tensorflow.contrib.distributions.python.ops.mvn_linear_operator) is deprecated and will be removed after 2018-10-01.\n",
      "Instructions for updating:\n",
      "The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.contrib.distributions`.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py:201: AffineLinearOperator.__init__ (from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator) is deprecated and will be removed after 2018-10-01.\n",
      "Instructions for updating:\n",
      "The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.contrib.distributions`.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py:158: _DistributionShape.__init__ (from tensorflow.contrib.distributions.python.ops.shape) is deprecated and will be removed after 2018-10-01.\n",
      "Instructions for updating:\n",
      "The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.contrib.distributions`.\n",
      "88588\n",
      "<tensorflow.python.layers.core.Dense object at 0x7fd95ccd95f8>\n",
      "<tensorflow.python.layers.core.Dense object at 0x7fd95ccd95f8>\n",
      "Tensor(\"decoder/while/dense/BiasAdd:0\", shape=(?, 88588), dtype=float32)\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/distributions/kullback_leibler.py:98: _kl_brute_force (from tensorflow.contrib.distributions.python.ops.mvn_linear_operator) is deprecated and will be removed after 2018-10-01.\n",
      "Instructions for updating:\n",
      "The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.contrib.distributions`.\n",
      "<tensorflow.python.layers.core.Dense object at 0x7fd95ccd95f8>\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = VAE(size_layer, num_layers, embedded_size, len(word2idx), learning_rate,\n",
    "           latent_size = size_layer)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def word_dropout(x):\n",
    "    is_dropped = np.random.binomial(1, 0.5, x.shape)\n",
    "    fn = np.vectorize(lambda x, k: word2idx['<UNK>'] if (\n",
    "                      k and (x not in range(4))) else x)\n",
    "    return fn(x, is_dropped)\n",
    "\n",
    "def inf_inp(test_strs):\n",
    "    x = [[word2idx.get(w, 2) for w in s.split()] for s in test_strs]\n",
    "    x = tf.keras.preprocessing.sequence.pad_sequences(\n",
    "        x, maxlen, truncating='post', padding='post')\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_strings = ['i love this film and i think it is one of the best films',\n",
    "             'this movie is a waste of time and there is no point to watch it']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 14, 120,  15,  23,   6,  14, 105,  13,  10,  32,   8,   5, 119,\n",
       "        109,   0,   0,   0,   0,   0,   0],\n",
       "       [ 15,  21,  10,   7, 438,   8,  59,   6,  51,  10,  58, 214,   9,\n",
       "        107,  13,   0,   0,   0,   0,   0]], dtype=int32)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inf_inp(test_strings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_input = word_dropout(Y_input[:2])\n",
    "y_output = Y_output[:2]\n",
    "x = X[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[459.78128,\n",
       " array([[-11.39117  , -11.388631 , -11.388212 , -11.389152 , -11.397253 ,\n",
       "         -11.3795185, -11.378984 , -11.402033 , -11.378487 , -11.396691 ,\n",
       "         -11.401232 , -11.407129 , -11.379465 , -11.389895 , -11.420394 ,\n",
       "         -11.384628 , -11.385522 , -11.402591 , -11.404229 , -11.391129 ],\n",
       "        [-11.390028 , -11.358953 , -11.391425 , -11.400146 , -11.411399 ,\n",
       "         -11.379161 , -11.400605 , -11.399052 , -11.405903 , -11.384191 ,\n",
       "         -11.404421 , -11.401221 , -11.382693 , -11.3972645, -11.398771 ,\n",
       "         -11.386691 , -11.408179 , -11.412223 , -11.3879795, -11.381235 ]],\n",
       "       dtype=float32),\n",
       " 455.73785,\n",
       " 604.13904,\n",
       " array([[[ 1.1628436e-02, -8.6111017e-05, -1.8049239e-03, ...,\n",
       "          -3.8907779e-04, -1.8487602e-02,  1.6926985e-02],\n",
       "         [ 1.6618120e-02, -3.0542999e-03, -5.1304079e-03, ...,\n",
       "          -1.0260832e-03, -1.5437857e-02,  1.8255688e-02],\n",
       "         [ 1.4503996e-02,  1.7808471e-04, -1.0922158e-02, ...,\n",
       "           3.9447471e-03, -9.6551636e-03,  8.8451887e-03],\n",
       "         ...,\n",
       "         [ 3.9398419e-03,  6.9508231e-03, -9.5413383e-03, ...,\n",
       "           2.1544114e-02,  5.4113334e-05, -1.8743049e-02],\n",
       "         [ 3.5744030e-03,  1.0301781e-02, -1.1494145e-02, ...,\n",
       "           2.4511881e-02, -1.8682680e-04, -2.3151923e-02],\n",
       "         [ 8.3817504e-03, -3.0663761e-04, -1.6889125e-02, ...,\n",
       "           2.4420906e-02, -4.5866393e-03, -2.0410130e-02]],\n",
       " \n",
       "        [[ 1.3329982e-02,  7.2369212e-04, -3.7250356e-03, ...,\n",
       "           5.7532005e-03, -4.8707668e-03,  3.1372299e-04],\n",
       "         [ 1.7307898e-02, -1.5590533e-03, -6.9943559e-03, ...,\n",
       "           2.9455894e-03, -6.4655514e-03,  4.4756061e-03],\n",
       "         [ 1.3527648e-02,  2.1658097e-03, -1.2396346e-02, ...,\n",
       "           6.1243912e-03, -2.4111450e-03, -2.4033338e-03],\n",
       "         ...,\n",
       "         [ 1.1664398e-03, -5.4866187e-03, -7.7285739e-03, ...,\n",
       "           1.4258313e-02,  3.3176029e-03, -2.2803556e-02],\n",
       "         [ 5.7728449e-04, -1.1413544e-03, -1.3149259e-02, ...,\n",
       "           1.7795686e-02,  4.4947616e-03, -2.6228081e-02],\n",
       "         [ 9.4815157e-05, -6.2796208e-03, -1.9580342e-02, ...,\n",
       "           1.7256645e-02,  2.9276810e-03, -1.9821122e-02]]], dtype=float32)]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sess.run([model.cost, model.out_dist, model.nll_loss, model.kl_loss, model.training_logits],\n",
    "         feed_dict = {model.X: x, model.Y_input: y_input,\n",
    "                      model.Y_output: y_output})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"enthusast maisie casket casket casket casket casket evasive evasive beaux beaux glyn glyn pigeon 'f\""
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r_aug = sess.run(model.predict_ids, feed_dict = {model.X: inf_inp(test_strings)})[0]\n",
    "' '.join([idx2word[r] for r in r_aug])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "epoch = 10\n",
    "batch_size = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [12:15<00:00,  4.26it/s, cost=3.66e+3]\n",
      "minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, average loss 3930.666338\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: this a movie you it a of movie you to it the of time this\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [12:14<00:00,  4.26it/s, cost=3.7e+3] \n",
      "minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2, average loss 3849.483341\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: to it i it a movie i it a movie i it it the of\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [12:15<00:00,  4.25it/s, cost=3.69e+3]\n",
      "minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3, average loss 3807.555884\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: i recommend movie i it the of i it i it i it the of\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [12:15<00:00,  4.24it/s, cost=3.64e+3]\n",
      "minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4, average loss 3771.349666\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: i saw movie i it the was i it i it it a of and\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [12:14<00:00,  4.25it/s, cost=3.59e+3]\n",
      "minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5, average loss 3741.504500\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: i to it the i this was the movie i it have seen a of\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [12:15<00:00,  4.25it/s, cost=3.61e+3]\n",
      "minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 6, average loss 3718.190249\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: to this i this is one the i ever it i it have the of\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [12:15<00:00,  4.25it/s, cost=3.58e+3]\n",
      "minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 7, average loss 3697.350334\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: i to it it a movie i it to it i it a out 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [12:15<00:00,  4.25it/s, cost=3.54e+3]\n",
      "minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 8, average loss 3680.469965\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: i this was a movie the of film i recommend to it anyone the of\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [12:15<00:00,  4.26it/s, cost=3.58e+3]\n",
      "minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 9, average loss 3664.853960\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: that would been to it the is a movie i it have seen a of\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [12:15<00:00,  4.25it/s, cost=3.54e+3]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 10, average loss 3650.547580\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: film i this is a movie is a movie i recommend to it anyone likes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for e in range(epoch):\n",
    "    pbar = tqdm(\n",
    "        range(0, len(X), batch_size), desc = 'minibatch loop')\n",
    "    cost = 0\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(X))\n",
    "        y_input = word_dropout(Y_input[i: index])\n",
    "        y_output = Y_output[i: index]\n",
    "        x = X[i: index]\n",
    "        c, _ = sess.run([model.cost, model.optimizer],\n",
    "         feed_dict = {model.X: x, model.Y_input: y_input,\n",
    "                      model.Y_output: y_output})\n",
    "        cost += c\n",
    "        pbar.set_postfix(cost = c)\n",
    "    cost /= (len(X) / batch_size)\n",
    "    r_aug = sess.run(model.predict_ids, feed_dict = {model.X: inf_inp(test_strings)})[0]\n",
    "    print('epoch %d, average loss %f'%(e + 1, cost))\n",
    "    print('real string: %s'%(test_strings[0]))\n",
    "    print('augmented string: %s'%(' '.join([idx2word[r] for r in r_aug])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "augmented string: there no but the was too but was i it have the and it a\n"
     ]
    }
   ],
   "source": [
    "r_aug = sess.run(model.predict_ids, feed_dict = {model.X: inf_inp(test_strings)})[0]\n",
    "print('augmented string: %s'%(' '.join([idx2word[r] for r in r_aug])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
